from functools import wraps
import os
from multiprocessing.dummy import Pool as ThreadPool
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Maximum total weight the knapsack may hold.
CAPACITY = 1_000
# Exclusive upper bound for random item values: they are drawn with
# np.random.randint(1, ITEM_VALUE), i.e. from 1 .. ITEM_VALUE - 1.
ITEM_VALUE = 20


class Item:
    """A knapsack item with a weight (cost) and a value (gain).

    Any extra positional or keyword arguments are accepted and ignored.
    """

    def __init__(self, weight, value, *args, **kwargs):
        self.weight, self.value = weight, value


n = 40  # number of candidate items

# Random instance: weight in 1..99, value in 1..ITEM_VALUE-1.
# The weight is drawn before the value for every item.
items = [
    Item(
        weight=np.random.randint(1, 100),
        value=np.random.randint(1, ITEM_VALUE),
    )
    for _ in range(n)
]

# Total value of all items (upper bound for any selection).
SUM_ITEM_VALUE = sum(item.value for item in items)


# Turn on matplotlib interactive mode so the figure can be redrawn while the
# script keeps running (the drawing functions call plt.pause()).
plt.ion()
# Axes3D is never referenced by name; presumably it is imported for its side
# effect of registering the '3d' projection used below (older matplotlib).
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
# https://matplotlib.org/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html?highlight=bar3d#mpl_toolkits.mplot3d.axes3d.Axes3D.bar3d
# bar3d(self, x, y, z, dx, dy, dz, color=None, zsort='average',  *args, **kwargs)
plt.style.use('fivethirtyeight')
# Left panel (ax): 3D bars for the items.
# Right panel (ax2): history of the DP / ant-colony values across redraws.
fig = plt.figure()
ax = fig.add_subplot(121, projection='3d')
ax2 = fig.add_subplot(122)
# Shared display state read by the plt_reset decorator. The placeholder "?"
# means "not computed yet"; plt_reset does not plot a value while it is "?".
dp_value = "?"
dp_value_history = []
ant_value = "?"
ant_value_history = []
dp_ans = []  # item indices selected by the DP solution (filled in later)
ant_ans = []  # selection of the displayed ant (filled in by the main loop)


def plt_reset(func):
    """Decorator: reset both axes and redraw the value history before ``func``.

    On every call, before delegating to ``func``, the wrapper does four things:

    * Clears the 3D item axes ``ax`` (axis hidden, fixed x/y/z limits) and the
      history axes ``ax2``.
    * Titles the figure with the current DP and ant values and selections.
    * Appends ``dp_value`` / ``ant_value`` to their histories once they are no
      longer the "?" placeholder, and plots them (DP red, ant gold).
    * Sanity-checks that the ant value never exceeds the DP optimum.

    The module-level ``ax``, ``ax2``, ``BLANK``, ``n`` and ``ITEM_VALUE`` are
    read at call time. They only need to exist when the decorated function is
    first called, not when it is decorated.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        ax.cla()
        ax2.cla()
        ax.axis('off')  # hide the 3D axis lines and panes
        ax.set_xlim(0, BLANK * n)
        ax.set_ylim(0, BLANK * n)
        # Item values are drawn below ITEM_VALUE, so this height fits every bar.
        ax.set_zlim(0, ITEM_VALUE)
        plt.title(f"{dp_value} {dp_ans}\n{ant_value} {ant_ans}")
        if dp_value != "?":
            dp_value_history.append(dp_value)
            ax2.plot(list(range(len(dp_value_history))),
                     dp_value_history, color="red")
        if ant_value != "?":
            ant_value_history.append(ant_value)
            ax2.plot(list(range(len(ant_value_history))),
                     ant_value_history, color="gold")
            # The exact DP optimum bounds every feasible ant solution.
            assert dp_value >= ant_value
        # Forward keyword arguments too (the old wrapper silently required
        # positional-only calls).
        return func(*args, **kwargs)
    return wrapper


# ---------------------------------------------------------------------------
# Exact solution: classic 0/1-knapsack dynamic programming.
# ---------------------------------------------------------------------------


def _knapsack_dp_table(item_list, capacity):
    """Build the 0/1-knapsack DP table.

    ``table[i][c]`` is the best total value reachable using only the first
    ``i`` items under a weight budget ``c``. Row 0 (no items) is all zeros.
    Returns a float array of shape ``(len(item_list) + 1, capacity + 1)``.
    """
    table = np.zeros([len(item_list) + 1, capacity + 1])
    for i, item in enumerate(item_list, start=1):
        prev = table[i - 1]
        table[i] = prev  # budgets that cannot (or do not) take item i-1
        w = int(item.weight)
        if w <= capacity:
            # For budgets c >= w, taking the item yields prev[c - w] + value.
            table[i, w:] = np.maximum(prev[w:],
                                      prev[:capacity + 1 - w] + item.value)
    return table


def _knapsack_backtrack(table, item_list, capacity):
    """Recover the chosen item indices (0-based, ascending) from ``table``."""
    chosen = []
    budget = capacity
    for i in range(len(item_list), 0, -1):
        # If adding item i-1 changed the optimum, that item was taken.
        if table[i - 1][budget] != table[i][budget]:
            chosen.append(i - 1)
            budget -= item_list[i - 1].weight
    chosen.reverse()  # collected back-to-front; O(1) appends instead of prepends
    return chosen


dp = _knapsack_dp_table(items, CAPACITY)
print(dp[n][CAPACITY])
dp_ans = _knapsack_backtrack(dp, items, CAPACITY)
print(dp_ans)
dp_value = dp[n][CAPACITY]
# Sanity check: the selected items' values must add up to the DP optimum.
check = sum(items[idx].value for idx in dp_ans)
assert check == dp_value
BLANK = 15  # spacing (plot units) between consecutive items along the x axis


@plt_reset
def show_knapsack_select(dp_ans):
    """Draw every item as a 3D bar, colouring the selected ones red.

    Item ``position`` sits at x = position * BLANK. Its footprint is
    sqrt(weight) x sqrt(weight) and its height is its value. Items whose index
    is in ``dp_ans`` are red, the rest blue. The figure is refreshed after each
    bar, which animates the drawing.
    """
    picked = set(dp_ans)
    for position, it in enumerate(items):
        side = np.sqrt(it.weight)
        shade = "red" if position in picked else "blue"
        ax.bar3d(position * BLANK, 0, 0, side, side, it.value, color=shade)
        plt.pause(0.1)


show_knapsack_select(dp_ans)


class Ant:
    """Walk state of one ant while it builds a take/skip selection.

    Extra positional or keyword arguments are accepted and ignored.
    """

    def __init__(self, *args, **kwargs):
        self.weight, self.value = 0, 0  # totals of the items taken so far
        self.u = (0, 0)  # current node: (layer index, state)
        self.tabu = []  # visited nodes, i.e. the path walked so far


const = 1  # initial pheromone level on every edge

# Pheromone tau[i, j, k]: edge from state j of layer i to state k of layer i+1,
# where state 1 means "take the item" and 0 means "skip it".
tau = np.full((n + 1, 2, 2), const, dtype=float)

# Heuristic desirability (a priori attractiveness) of each edge. It is uniform,
# so only the pheromone distinguishes edges. A disabled alternative weighted
# the "take" edges by the value/weight ratio:
# for i in range(n):
#     eta[i][0][0] = eta[i][1][0] = 1
#     idx = i
#     eta[i][0][1] = eta[i][1][1] = np.clip(
#         (items[idx].value / items[idx].weight), 0.1, 2)
eta = np.full((n + 1, 2, 2), 1.0)


def Pk_ij(ant_k, i, j, alpha=1, beta=1):
    """State-transition rule: sample the ant's next state.

    The ant is in state ``j`` of layer ``i`` and decides item ``i``: state 0
    skips it, state 1 takes it. Taking is allowed only if the item still fits
    in CAPACITY given ``ant_k.weight``. Among the allowed states ``k`` the
    choice is random, with probability proportional to
    ``tau[i, j, k] ** alpha * eta[i, j, k] ** beta``.

    Args:
        ant_k: the ant; only its accumulated ``weight`` is read.
        i: current layer, which is also the index of the item being decided.
        j: current state (0 or 1).
        alpha: exponent weighting the pheromone ``tau``.
        beta: exponent weighting the heuristic ``eta``.

    Returns:
        The next state ``v``: 0 (skip item ``i``) or 1 (take it).
    """
    allowed = [0]
    if ant_k.weight + items[i].weight <= CAPACITY:
        allowed.append(1)
    weights = (np.power(tau[i, j, allowed], alpha)
               * np.power(eta[i, j, allowed], beta))
    probs = weights / np.sum(weights)
    # Draw exactly once. The previous version made an extra, discarded draw
    # that only advanced the RNG state.
    return allowed[np.random.choice(len(allowed), p=probs)]


N_cmax = 1000  # number of ant-colony iterations
m = 100  # number of ants
# Pheromone deposit for the ant-cycle model (intended as Q / ant[k].value).
# NOTE(review): Q is not referenced by the visible update code, which deposits
# a fixed amount instead.
Q = 1
ants = [Ant() for _slot in range(m)]


@plt_reset
def show_ant_system(ant_k):
    """Visualise one ant's path through the skip/take state graph.

    Column ``col`` (x = col * BLANK) shows item ``col`` twice:

    * a flat bar on the "skip" row (y offset 0);
    * a bar of height ``item.value`` on the "take" row (y offset BLANK * 3).

    The row the ant chose is red and the other blue. Each column is joined to
    the previous one by all four skip/take edges. The edges are drawn in black,
    with line width equal to the edge's current pheromone level in ``tau``.
    """
    row_gap = BLANK * 3
    prev_x, prev_y, prev_z = -BLANK, 0, 0
    prev_state = (0, 0)
    for col, (item, state) in enumerate(zip(items, ant_k.tabu)):
        cur_x, cur_y, cur_z = prev_x + BLANK, prev_y, prev_z
        side = np.sqrt(item.weight)
        layer = prev_state[0]

        if state[1] == 0:
            skip_color, take_color = "red", "blue"
        else:
            skip_color, take_color = "blue", "red"
        ax.bar3d(cur_x, cur_y, cur_z, side, side, 0.1, color=skip_color)
        ax.bar3d(cur_x, cur_y + row_gap, cur_z, side, side,
                 item.value, color=take_color)

        # Edges from each row of the previous column to each row of this one.
        for src_row in range(2):
            z_start = (prev_z if src_row == 0 or col == 0
                       else items[col - 1].value)
            y_start = prev_y + src_row * row_gap
            for dst_row in range(2):
                z_end = cur_z if dst_row == 0 else item.value
                ax.plot([prev_x, cur_x],
                        [y_start, cur_y + dst_row * row_gap],
                        [z_start, z_end],
                        color="black",
                        linewidth=tau[layer][src_row][dst_row])

        prev_state = state
        prev_x, prev_y, prev_z = cur_x, cur_y, cur_z
    plt.pause(1e-4)


ant_max_value = 0  # best total value reached by any ant so far (all iterations)


def _construct_solution(ant_k):
    """Reset ``ant_k`` and let it decide take (1) / skip (0) for every item.

    The ant advances one layer per item. After deciding item ``i`` it moves
    to node ``(i + 1, v)``, which is appended to ``ant_k.tabu``.
    """
    ant_k.weight = 0
    ant_k.value = 0
    ant_k.u = (0, 0)
    ant_k.tabu = []
    for idx in range(n):
        i, j = ant_k.u  # i == idx: the ant is at the layer of item idx
        v = Pk_ij(ant_k, i, j)
        if v == 1:
            ant_k.weight += items[idx].weight
            ant_k.value += items[idx].value
        ant_k.u = (i + 1, v)
        ant_k.tabu.append(ant_k.u)


def update_tau(p=0.1):
    """Evaporate and deposit pheromone (ant-cycle model).

    Each ant whose value matches or beats the running best ``ant_max_value``
    raises that best and deposits 3 on every edge of its path (delta_tau,
    Eq. 6.5 of the source text; the ``n`` in tau_ij(t+n) is ignored). Then
    ``tau = (1 - p) * tau + delta_tau`` (Eq. 6.4), clipped to [0.1, 3].

    Args:
        p: evaporation rate.
    """
    global tau
    global ant_max_value
    deposit = 3
    delta_tau = np.zeros([n + 1, 2, 2])
    for ant_k in ants:
        if ant_k.value < ant_max_value:
            # Skip only this ant. The old `break` also discarded every later
            # ant, even ones that beat the best.
            continue
        ant_max_value = ant_k.value
        prev = (0, 0)
        for node in ant_k.tabu:
            (i, j), (_, k) = prev, node
            delta_tau[i, j, k] = max(deposit, delta_tau[i, j, k])
            prev = node
    tau = (1 - p) * tau + delta_tau
    tau = np.clip(tau, 0.1, 3)
    print(np.max(tau))


def _selected_items(ant_k):
    """Return the 0-based indices of the items ``ant_k`` took.

    Uses the same indexing as ``dp_ans``. Node ``(layer, 1)`` in the path
    means item ``layer - 1`` was taken.
    """
    return [layer - 1 for layer, state in ant_k.tabu if state]


for N_c in range(N_cmax):
    for ant_k in ants:
        _construct_solution(ant_k)
    update_tau()
    ant_value = ants[0].value
    # Previously stored 1-based layer numbers, off by one from dp_ans.
    ant_ans = _selected_items(ants[0])
    show_ant_system(ants[0])
